"""Logic related to worker compliance checks, which seek to ensure workers do not produce malicious responses."""

import datetime
from typing import cast

import fastapi
import sqlmodel
from loguru import logger
from oasst_inference_server import database, deps, models, worker_utils
from oasst_inference_server.settings import settings
from oasst_shared.schemas import inference
from sqlalchemy.sql.functions import random as sql_random
from sqlmodel import not_, or_


async def find_compliance_work_request_message(
    session: database.AsyncSession, worker_config: inference.WorkerConfig, worker_id: str
) -> models.DbMessage | None:
    """
    Find a suitable assistant message to carry out a worker compliance check for the given worker. Such a message must
    have been generated by a different worker, but one with the same compatibility hash as the given worker.
    """
    compat_hash = worker_config.compat_hash
    query = (
        sqlmodel.select(models.DbMessage)
        .where(
            models.DbMessage.role == "assistant",
            models.DbMessage.state == inference.MessageState.complete,
            models.DbMessage.worker_compat_hash == compat_hash,
            models.DbMessage.worker_id != worker_id,
        )
        .order_by(sql_random())
    )
    message = (await session.exec(query)).first()
    return message


async def should_do_compliance_check(session: database.AsyncSession, worker_id: str) -> bool:
    """
    Check whether we should carry out a compliance check for the given worker, based on time since last check.
    Trusted workers are excluded.
    """
    worker = await worker_utils.get_worker(worker_id, session)
    if worker.trusted:
        return False
    if worker.in_compliance_check:
        return False
    if worker.next_compliance_check is None:
        return True
    if worker.next_compliance_check < datetime.datetime.utcnow():
        return True
    return False


async def run_compliance_check(websocket: fastapi.WebSocket, worker_id: str, worker_config: inference.WorkerConfig):
    """
    Run a compliance check for the given worker:
    - Find a suitable compliance check assistant message
    - Task the worker with generating a response with the same context
    - Compare the respons against the existing completed message
    - Update the database with the outcome
    """
    async with deps.manual_create_session() as session:
        try:
            worker = await worker_utils.get_worker(worker_id, session)
            if worker.in_compliance_check:
                logger.info(f"Worker {worker.id} is already in compliance check")
                return
            worker.in_compliance_check_since = datetime.datetime.utcnow()
        finally:
            await session.commit()

    logger.info(f"Running compliance check for worker {worker_id}")

    async with deps.manual_create_session(autoflush=False) as session:
        compliance_check = models.DbWorkerComplianceCheck(worker_id=worker_id)

        try:
            message = await find_compliance_work_request_message(session, worker_config, worker_id)
            if message is None:
                logger.warning(
                    f"Could not find message for compliance check for worker {worker_id} with config {worker_config}"
                )
                return

            compliance_check.compare_worker_id = message.worker_id
            compliance_work_request = await worker_utils.build_work_request(session, message.id)

            logger.info(f"Found work request for compliance check for worker {worker_id}: {compliance_work_request}")
            await worker_utils.send_worker_request(websocket, compliance_work_request)
            response = None
            while True:
                response = await worker_utils.receive_worker_response(websocket)
                if response.response_type == "error":
                    compliance_check.responded = True
                    compliance_check.error = response.error
                    logger.warning(f"Worker {worker_id} errored during compliance check: {response.error}")
                    return
                if response.response_type == "generated_text":
                    break
            if response is None:
                logger.warning(f"Worker {worker_id} did not respond to compliance check")
                return
            compliance_check.responded = True
            response = cast(inference.GeneratedTextResponse, response)
            passes = response.text == message.content
            compliance_check.passed = passes
            logger.info(f"Worker {worker_id} passed compliance check: {passes}")

        finally:
            compliance_check.end_time = datetime.datetime.utcnow()
            session.add(compliance_check)
            worker = await worker_utils.get_worker(worker_id, session)
            worker.next_compliance_check = datetime.datetime.utcnow() + datetime.timedelta(
                seconds=settings.compliance_check_interval
            )
            worker.in_compliance_check_since = None
            logger.info(f"set next compliance check for worker {worker_id} to {worker.next_compliance_check}")
            await session.commit()
            await session.flush()


async def maybe_do_compliance_check(websocket, worker_id, worker_config, worker_session_id):
    async with deps.manual_create_session() as session:
        should_check = await should_do_compliance_check(session, worker_id)
    if should_check:
        logger.info(f"Worker {worker_id} needs compliance check")
        try:
            await worker_utils.update_worker_session_status(
                worker_session_id, worker_utils.WorkerSessionStatus.compliance_check
            )
            await run_compliance_check(websocket, worker_id, worker_config)
        finally:
            await worker_utils.update_worker_session_status(worker_session_id, worker_utils.WorkerSessionStatus.waiting)


async def compute_worker_compliance_score(worker_id: str) -> float:
    """
    Compute a float between 0 and 1 (inclusive) representing the compliance score of the worker.
    Workers are rewarded for passing compliance checks, and penalised for failing to respond to a check, erroring during a check, or failing a check.
    In-progress checks are ignored.
    """
    async with deps.manual_create_session() as session:
        query = sqlmodel.select(models.DbWorkerComplianceCheck).where(
            or_(
                models.DbWorkerComplianceCheck.worker_id == worker_id,
                models.DbWorkerComplianceCheck.compare_worker_id == worker_id,
            ),
            not_(models.DbWorkerComplianceCheck.end_time.is_(None)),
        )
        worker_checks: list[models.DbWorkerComplianceCheck] = (await session.exec(query)).all()

        # Rudimentary scoring algorithm, we may want to add weightings or other factors
        total_count = len(worker_checks)

        checked = [c for c in worker_checks if c.worker_id == worker_id]
        compared = [c for c in worker_checks if c.compare_worker_id == worker_id]

        pass_count = sum(1 for _ in filter(lambda c: c.passed, checked))
        error_count = sum(1 for _ in filter(lambda c: c.error is not None, checked))
        no_response_count = sum(1 for _ in filter(lambda c: not c.responded, checked))

        compare_fail_count = sum(1 for _ in filter(lambda c: not c.passed, compared))
        fail_count = len(checked) - pass_count - error_count - no_response_count

        return (fail_count + compare_fail_count) / total_count
